[quantization] Introduce wrapper for Qwen3VLTextRotaryEmbedding #498
dvsav wants to merge 1 commit into Samsung:main
Conversation
For Reviewers

Below is the source code of `Qwen3VLTextRotaryEmbedding` from `transformers/models/qwen3_vl/modeling_qwen3_vl.py`:
```python
class Qwen3VLTextRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen3VLTextConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
        self.mrope_section = config.rope_parameters.get("mrope_section", [24, 20, 20])

    @staticmethod
    def compute_default_rope_parameters(
        config: Qwen3VLTextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation.

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # In contrast to other models, Qwen3VL has different position ids for the grids,
        # so we expand the inv_freq to shape (3, ...)
        if position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Apply interleaved MRoPE to 3D rotary embeddings.
        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THWTHWTHW...TT], preserving frequency continuity.

        Args:
            freqs: (3, bs, seq_len, head_dim // 2)
            mrope_section: (3,)
        Returns:
            freqs_t: (bs, seq_len, head_dim // 2)
        """
        freqs_t = freqs[0]  # just overwrite the first dimension T
        for dim, offset in enumerate((1, 2), start=1):  # H, W
            length = mrope_section[dim] * 3
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t
```
This change introduces the `QuantQwen3VLTextRotaryEmbedding` wrapper to support post-training quantization of the `Qwen3VLTextRotaryEmbedding` module.

TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
The reworked `apply_interleaved_mrope` under review:

```python
def apply_interleaved_mrope(self, freqs, mrope_section):
    """
    Apply interleaved MRoPE to 3D rotary embeddings.
    Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
    interleaved [THWTHWTHW...TT], preserving frequency continuity.

    Args:
        freqs: (3, bs, seq_len, head_dim // 2)
        mrope_section: (3,)

    Returns:
        freqs_t: (bs, seq_len, head_dim // 2)

    Design Note:
        This implementation uses slice_copy, index_select, and cat
        to avoid the not-yet-supported slice_scatter operation with step=3
        and the unsupported in-place operator index_put.default.
    """
    # Start with the T dimension (some positions are kept, some replaced)
    freqs_t_base = freqs[0]

    # For each dimension (H, W), extract the frequency bands to be interleaved
    h_w_bands = []
    for dim, offset in enumerate((1, 2), start=1):  # H, W dimensions
        length = mrope_section[dim] * 3
        indices = torch.arange(offset, length, 3, device=freqs.device)

        # Select frequency bands from the H/W dimensions.
        # freqs[dim] has shape (bs, seq_len, head_dim // 2);
        # index_select on the last dim gives (bs, seq_len, num_selected)
        freqs_bands = freqs[dim].index_select(dim=-1, index=indices)
        h_w_bands.append(freqs_bands)

    # Build the interleaved output. The T dimension keeps its original
    # indices except where H/W bands are spliced in. The interleaving
    # pattern is T0, H1, W2, T3, T4, H5, W6, T7, ... where the T, H, W
    # bands follow the layout given by mrope_section.
    #
    # Strategy: slice the T dimension into chunks, insert the H/W bands,
    # then concatenate everything back together.
    chunks = []

    # Total length in the last dimension
    total_len = freqs_t_base.shape[-1]

    for i in range(total_len):
        # Determine which dimension this position belongs to
        # Pattern: T, H, W, T, T, H, W, T, ...
        mod = i % 3

        if mod == 0:
            # T position: slice just this index from T
            chunks.append(freqs_t_base[..., i : i + 1])
        elif mod == 1:
            # H position: compute which band this is
            band_idx = (i - 1) // 3
            if band_idx < h_w_bands[0].shape[-1]:
                chunks.append(h_w_bands[0][..., band_idx : band_idx + 1])
            else:
                # Fall back to T if out of bounds
                chunks.append(freqs_t_base[..., i : i + 1])
        else:  # mod == 2
            # W position: compute which band this is
            band_idx = (i - 2) // 3
            if band_idx < h_w_bands[1].shape[-1]:
                chunks.append(h_w_bands[1][..., band_idx : band_idx + 1])
            else:
                # Fall back to T if out of bounds
                chunks.append(freqs_t_base[..., i : i + 1])

    # Concatenate all chunks along the last dimension
    freqs_t = torch.cat(chunks, dim=-1)

    return freqs_t
```
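As a sanity check (my sketch, not part of the PR), the functional rewrite can be compared against the original in-place logic on random inputs using the real Qwen3VL section sizes; `interleave_functional` below is a condensed re-expression of the PR's loop, not the PR code verbatim:

```python
import torch

def interleave_inplace(freqs, mrope_section):
    # Original transformers logic (in-place slice assignment)
    freqs_t = freqs[0].clone()
    for dim, offset in enumerate((1, 2), start=1):  # H, W
        idx = slice(offset, mrope_section[dim] * 3, 3)
        freqs_t[..., idx] = freqs[dim, ..., idx]
    return freqs_t

def interleave_functional(freqs, mrope_section):
    # Condensed version of the PR's functional approach
    bands = [
        freqs[dim].index_select(-1, torch.arange(offset, mrope_section[dim] * 3, 3))
        for dim, offset in ((1, 1), (2, 2))
    ]
    chunks = []
    for i in range(freqs.shape[-1]):
        mod = i % 3
        band_idx = (i - mod) // 3
        if mod != 0 and band_idx < bands[mod - 1].shape[-1]:
            chunks.append(bands[mod - 1][..., band_idx : band_idx + 1])
        else:
            chunks.append(freqs[0][..., i : i + 1])  # T position or fallback
    return torch.cat(chunks, dim=-1)

freqs = torch.randn(3, 2, 5, 64)  # head_dim // 2 == 64 for Qwen3VL
section = [24, 20, 20]            # Qwen3VL default mrope_section
assert torch.equal(interleave_inplace(freqs, section),
                   interleave_functional(freqs, section))
```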
Note for Reviewers
Replicating the original `Qwen3VLTextRotaryEmbedding.apply_interleaved_mrope` implementation leads to errors at the time of conversion to Circle.
The original implementation uses `slice(offset, length, 3)`, which emits a `slice_scatter` operator with `step=3` when the model is exported.
When the model is converted to Circle, the `DecomposeSliceScatter` pass fails with the following error: `RuntimeError: slice_scatter with step > 1 is not yet supported. Node: slice_scatter`.
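A minimal repro sketch (mine, not from the PR; the module name and shapes are made up for illustration) that exports a strided in-place slice assignment and scans the traced graph for `slice_scatter`:

```python
import torch

class StridedSliceAssign(torch.nn.Module):
    # Hypothetical toy module mimicking the freqs_t[..., idx] = ... pattern
    def forward(self, t: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        out = t.clone()
        out[..., 1::3] = h[..., 1::3]  # in-place assignment with step=3
        return out

ep = torch.export.export(
    StridedSliceAssign(), (torch.randn(2, 8, 12), torch.randn(2, 8, 12))
)
# Functionalization rewrites the in-place assignment into a slice_scatter node
print(any("slice_scatter" in str(n.target) for n in ep.graph.nodes))  # expected: True
```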
Approaches leveraging in-place operations don't work either. For example:

```python
# Create the list of indices manually (avoiding step=3 in a slice)
idx = list(range(offset, length, 3))
# Assign in-place at those indices
freqs_t[..., idx] = freqs[dim, ..., idx]
```

This approach fails because it emits the `index_put` operator, which isn't supported in Circle (`tico/utils/convert.py` raises `tico.utils.errors.NotYetSupportedError: NOT SUPPORTED OPERATOR IN GRAPH MODULE`).
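This failure mode can also be reproduced in isolation (my sketch, not PR code; names and shapes are made up):

```python
import torch

class ListIndexAssign(torch.nn.Module):
    # Hypothetical toy module: in-place assignment via a Python index list
    def forward(self, t: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        out = t.clone()
        out[..., [1, 4, 7]] = h[..., [1, 4, 7]]  # advanced-indexing assignment
        return out

ep = torch.export.export(
    ListIndexAssign(), (torch.randn(2, 3, 9), torch.randn(2, 3, 9))
)
# The assignment is traced to an index_put node
print(any("index_put" in str(n.target) for n in ep.graph.nodes))  # expected: True
```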
The same goes for the following example (`index_copy_` generates `index_put` as well):

```python
# Create a tensor of indices with torch.arange
# (this avoids the Python list, which triggers index_put directly)
indices = torch.arange(offset, length, 3, device=freqs.device)

# Select the bands from the source dimension.
# freqs has shape (3, batch, seq_len, head_dim // 2);
# keep all batch/seq positions, only specific indices in the last dim
src_selected = freqs[dim].index_select(dim=-1, index=indices)

# Copy into the target with index_copy_ (also traced to index_put)
freqs_t.index_copy_(dim=-1, index=indices, source=src_selected)
```
The only solution that worked was a pure functional approach using basic tensor slicing (converted to `slice_copy` operators during `torch.export`), `index_select`, and `cat`.

Why in-place operations fail: any in-place tensor update (like `tensor[...] = value`, `index_copy_`, or `index_put_`) gets traced by `torch.export` to an operator that writes into existing tensor memory. The Circle runtime model (like TFLite's) is designed for functional computation without in-place mutation, so these operators are not supported.

The functional approach works because it:

- Builds intermediate tensors via `slice_copy` and `index_select` (read-only operations).
- Combines them via `cat` (which creates a new tensor rather than modifying an existing one).
- Never mutates tensors in-place, thus avoiding the unsupported operators.
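To double-check this claim, here is a sketch (mine, not from the PR) that exports a purely functional toy interleave and confirms that no scatter or index_put style mutation ops appear in the traced graph:

```python
import torch

class FunctionalInterleave(torch.nn.Module):
    # Hypothetical toy: strict T,H,W repeating pattern, built only from
    # slicing and cat, with no in-place mutation anywhere
    def forward(self, freqs: torch.Tensor) -> torch.Tensor:
        # freqs: (3, bs, seq_len, d)
        chunks = [freqs[i % 3][..., i : i + 1] for i in range(freqs.shape[-1])]
        return torch.cat(chunks, dim=-1)

ep = torch.export.export(FunctionalInterleave(), (torch.randn(3, 2, 4, 6),))
ops = {str(n.target) for n in ep.graph.nodes if n.op == "call_function"}
# Only functional ops (slices, selects, and cat) are expected in the graph
print(any("scatter" in op or "index_put" in op for op in ops))  # expected: False
```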
This change introduces the `QuantQwen3VLTextRotaryEmbedding` wrapper to support post-training quantization of the `Qwen3VLTextRotaryEmbedding` module.

Why?

The `Qwen3VLTextRotaryEmbedding` module is used in the language model of Qwen. Trying to quantize `Qwen3VLTextRotaryEmbedding` via PTQ raises the exception `PTQQuantizer: no quantization wrapper for Qwen3VLTextRotaryEmbedding`.

What

This change introduces:

- `QuantQwen3VLTextRotaryEmbedding` (`tico/quantization/wrapq/wrappers/qwen_vl/quant_text_rotary_embedding.py`).
- `class TestQuantQwen3VLTextRotaryEmbedding` (`test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py`), skipped if the `transformers` package is not installed.
- Registration of `tico.quantization.wrapq.wrappers.qwen_vl.quant_text_rotary_embedding` in `_CORE_MODULES` (`tico/quantization/wrapq/wrappers/registry.py`).
- An example of `Qwen3VLTextRotaryEmbedding` quantization and conversion to Circle (`tico/quantization/wrapq/examples/qwen/quantize_qwen_text_rotary_embedding.py`).

Unit Tests
Unit test results with coverage information: